import torch
from torch.utils.data import DataLoader
from torch_geometric.data import TemporalData
from typing import List

# Names exported when this module is star-imported: the two loaders and the
# id-relabelling helper defined below.
__all__ = ['TemporalDataLoader', 'ScheduledLoader', 'bipartite_reindexer']


def bipartite_reindexer(data):
    """Relabel ``data.src`` and ``data.dst`` with contiguous ids starting at 1.

    Each tensor is relabelled on its own: every entry is replaced by the
    inverse index reported by ``Tensor.unique(return_inverse=True)`` plus one.
    Because the two relabellings are independent, the same integer may denote
    both a source node and a destination node afterwards (the source ids are
    not shifted past the destination ids).

    Args:
        data: Object exposing integer tensors ``src`` and ``dst``. It is
            modified in place.

    Returns:
        The same ``data`` object, with ``src`` and ``dst`` replaced.
    """
    dst_codes = data.dst.unique(return_inverse=True)[1]
    src_codes = data.src.unique(return_inverse=True)[1]
    # Shift from 0-based inverse indices to 1-based ids.
    data.dst, data.src = dst_codes + 1, src_codes + 1
    return data


class ScheduledLoader(DataLoader):
    """Temporal loader whose first batch is enlarged by a warm-up schedule.

    The events of ``data`` are cut into consecutive windows of ``batch_size``
    events. The leading ``schedule // batch_size`` windows are merged into a
    single first batch; every later batch holds ``batch_size`` events (the
    final one may be shorter unless ``drop_last=True``).

    Args:
        data: Event stream supporting ``len()``, slicing and the ``src`` /
            ``dst`` attributes.
        schedule: Number of leading events to merge into the first batch. It
            is rounded down to a multiple of the batch size; values smaller
            than two batches (including negative values) merge nothing.
        batch_size: Events per regular batch. A non-positive value means
            "all events in one batch", matching ``TemporalDataLoader``.
        neg_sampling_ratio: When positive, each batch also receives
            ``neg_dst``: ``round(ratio * len(batch.dst))`` destination ids
            drawn uniformly from ``[data.dst.min(), data.dst.max()]``.
        **kwargs: ``drop_last`` is honoured (a trailing incomplete window is
            discarded). ``dataset``, ``collate_fn`` and ``shuffle`` are
            ignored. NOTE(review): any other keyword is currently *not*
            forwarded to ``DataLoader`` -- confirm whether that is intended.
    """

    def __init__(
        self,
        data: TemporalData,
        schedule: int,
        batch_size: int = 1,
        neg_sampling_ratio: float = 0.0,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)
        kwargs.pop('shuffle', None)

        self.data = data
        self.neg_sampling_ratio = neg_sampling_ratio

        if neg_sampling_ratio > 0:
            self.min_dst = int(data.dst.min())
            self.max_dst = int(data.dst.max())

        num_events = len(data)
        # A non-positive batch size selects a single full batch; the lower
        # bound of 1 keeps ``range`` valid for an empty event stream.
        step = batch_size if batch_size > 0 else max(num_events, 1)

        if kwargs.get('drop_last', False):
            # Stop at the last complete window.
            end = num_events - num_events % step
        else:
            end = num_events

        starts = list(range(0, end, step))
        # Merge the leading windows covered by ``schedule`` into the first
        # batch by removing their start indices. Clamping to 1 turns small or
        # negative schedules into a no-op instead of a negative slice bound.
        del starts[1:max(schedule // step, 1)]

        # Close the final window explicitly: pairing consecutive starts alone
        # would silently drop the events after the last start index.
        boundaries = starts + [end]
        windows = list(zip(boundaries, boundaries[1:]))

        super().__init__(windows, 1, shuffle=False, collate_fn=self)

    def __call__(self, arange: List[tuple]) -> TemporalData:
        """Collate one ``(start, stop)`` window into a batch of events.

        Args:
            arange: One-element list holding the ``(start, stop)`` pair
                selected by the underlying ``DataLoader`` (its internal batch
                size is fixed to 1).

        Returns:
            The slice ``data[start:stop]`` with ``n_id`` set to the unique
            node ids occurring in ``src``, ``dst`` and, when negative
            sampling is enabled, ``neg_dst``.
        """
        start, stop = arange[0]
        batch = self.data[start:stop]

        n_ids = [batch.src, batch.dst]

        if self.neg_sampling_ratio > 0:
            batch.neg_dst = torch.randint(
                low=self.min_dst,
                high=self.max_dst + 1,  # ``high`` is exclusive
                size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ),
                dtype=batch.dst.dtype,
                device=batch.dst.device,
            )
            n_ids += [batch.neg_dst]

        batch.n_id = torch.cat(n_ids, dim=0).unique()

        return batch

    def schedule_batching(self, data, batch_size, **kwargs):
        """Placeholder: currently returns an empty list for any arguments."""
        schedule = []
        # TODO
        return schedule


class TemporalDataLoader(DataLoader):
    """Loader yielding consecutive, fixed-size batches of temporal events.

    Args:
        data: Event stream supporting ``len()``, slicing and the ``src`` /
            ``dst`` attributes.
        batch_size: Events per batch. A non-positive value means "all events
            in one batch".
        neg_sampling_ratio: When positive, each batch also receives
            ``neg_dst``: ``round(ratio * len(batch.dst))`` destination ids
            drawn uniformly from ``[data.dst.min(), data.dst.max()]``.
        **kwargs: ``drop_last`` is honoured (a trailing incomplete batch is
            discarded). ``dataset``, ``collate_fn`` and ``shuffle`` are
            ignored. NOTE(review): any other keyword is currently *not*
            forwarded to ``DataLoader`` -- confirm whether that is intended.
    """

    def __init__(
        self,
        data: TemporalData,
        batch_size: int = 1,
        neg_sampling_ratio: float = 0.0,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        kwargs.pop('dataset', None)
        kwargs.pop('collate_fn', None)
        kwargs.pop('shuffle', None)

        self.data = data
        num_events = len(data)
        if batch_size > 0:
            self.events_per_batch = batch_size
        else:
            self.events_per_batch = num_events
        self.neg_sampling_ratio = neg_sampling_ratio

        if neg_sampling_ratio > 0:
            self.min_dst = int(data.dst.min())
            self.max_dst = int(data.dst.max())

        # Derive the start indices from the *effective* batch size. Using the
        # raw ``batch_size`` made the documented "non-positive means full
        # batch" mode fail (zero step) or yield nothing (negative step). The
        # lower bound of 1 keeps ``range`` valid for an empty event stream.
        step = max(self.events_per_batch, 1)

        if kwargs.get('drop_last', False):
            # Stop at the last complete batch.
            stop = num_events - num_events % step
        else:
            stop = num_events
        arange = range(0, stop, step)

        super().__init__(arange, 1, shuffle=False, collate_fn=self)

    def __call__(self, arange: List[int]) -> TemporalData:
        """Collate one start index into a batch of events.

        Args:
            arange: One-element list holding the start index selected by the
                underlying ``DataLoader`` (its internal batch size is fixed
                to 1).

        Returns:
            The slice ``data[start:start + events_per_batch]`` with ``n_id``
            set to the unique node ids occurring in ``src``, ``dst`` and,
            when negative sampling is enabled, ``neg_dst``.
        """
        start = arange[0]
        batch = self.data[start:start + self.events_per_batch]

        n_ids = [batch.src, batch.dst]

        if self.neg_sampling_ratio > 0:
            batch.neg_dst = torch.randint(
                low=self.min_dst,
                high=self.max_dst + 1,  # ``high`` is exclusive
                size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ),
                dtype=batch.dst.dtype,
                device=batch.dst.device,
            )
            n_ids += [batch.neg_dst]

        batch.n_id = torch.cat(n_ids, dim=0).unique()

        return batch
